{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-16:39:52.721.890 [mindspore/_check_version.py:207] MindSpore version 1.1.1 and \"te\" wheel package version 1.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "MindSpore version 1.1.1 and \"topi\" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-16:39:53.183.964 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import collections\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "import mindspore.common.dtype as mstype\n",
    "from mindspore import context\n",
    "from mindspore import log as logger\n",
    "from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell\n",
    "from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum\n",
    "from mindspore.common.tensor import Tensor\n",
    "from mindspore.train.model import Model\n",
    "from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor\n",
    "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
    "\n",
    "from src.dataset import create_squad_dataset\n",
    "from src.bert_for_finetune import BertSquadCell, BertSquad\n",
    "from src.finetune_eval_config import optimizer_cfg, bert_net_cfg\n",
    "from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate\n",
    "\n",
    "_cur_dir = os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_train(dataset=None, network=None, load_checkpoint_path=\"\", save_checkpoint_path=\"\", epoch_num=1):\n",
    "    \"\"\"Fine-tune `network` on `dataset`, starting from a pretrained checkpoint.\n",
    "\n",
    "    Args:\n",
    "        dataset: Training dataset; its `get_dataset_size()` gives the steps per epoch.\n",
    "        network: Network to train; the checkpoint parameters are loaded into it.\n",
    "        load_checkpoint_path: Path of the pretrained checkpoint (required).\n",
    "        save_checkpoint_path: Checkpoint output directory; \"\" passes `directory=None`\n",
    "            to `ModelCheckpoint`.\n",
    "        epoch_num: Number of epochs to train.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `load_checkpoint_path` is empty.\n",
    "        Exception: If `optimizer_cfg.optimizer` names an unsupported optimizer.\n",
    "    \"\"\"\n",
    "    if load_checkpoint_path == \"\":\n",
    "        raise ValueError(\"Pretrain model missed, finetune task must load pretrain model!\")\n",
    "    steps_per_epoch = dataset.get_dataset_size()\n",
    "\n",
    "    def _warmup_decay_schedule(cfg):\n",
    "        \"\"\"Build the learning-rate schedule; warmup_steps is 10% of the total step count.\"\"\"\n",
    "        total_steps = steps_per_epoch * epoch_num\n",
    "        return BertLearningRate(learning_rate=cfg.learning_rate,\n",
    "                                end_learning_rate=cfg.end_learning_rate,\n",
    "                                warmup_steps=int(total_steps * 0.1),\n",
    "                                decay_steps=total_steps,\n",
    "                                power=cfg.power)\n",
    "\n",
    "    # Select the optimizer named in the fine-tune configuration.\n",
    "    optimizer_name = optimizer_cfg.optimizer\n",
    "    if optimizer_name == 'AdamWeightDecay':\n",
    "        adam_cfg = optimizer_cfg.AdamWeightDecay\n",
    "        lr_schedule = _warmup_decay_schedule(adam_cfg)\n",
    "        # Split the parameters: those accepted by decay_filter get weight decay,\n",
    "        # the remaining ones are trained with weight_decay=0.0.\n",
    "        decay_params, other_params = [], []\n",
    "        for param in network.trainable_params():\n",
    "            if adam_cfg.decay_filter(param):\n",
    "                decay_params.append(param)\n",
    "            else:\n",
    "                other_params.append(param)\n",
    "        group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},\n",
    "                        {'params': other_params, 'weight_decay': 0.0}]\n",
    "        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)\n",
    "    elif optimizer_name == 'Lamb':\n",
    "        lr_schedule = _warmup_decay_schedule(optimizer_cfg.Lamb)\n",
    "        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)\n",
    "    elif optimizer_name == 'Momentum':\n",
    "        momentum_cfg = optimizer_cfg.Momentum\n",
    "        optimizer = Momentum(network.trainable_params(), learning_rate=momentum_cfg.learning_rate,\n",
    "                             momentum=momentum_cfg.momentum)\n",
    "    else:\n",
    "        raise Exception(\"Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]\")\n",
    "\n",
    "    # Checkpoint callback: save every `steps_per_epoch` steps, keep at most one file.\n",
    "    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)\n",
    "    ckpt_dir = save_checkpoint_path if save_checkpoint_path != \"\" else None\n",
    "    ckpoint_cb = ModelCheckpoint(prefix=\"squad\", directory=ckpt_dir, config=ckpt_config)\n",
    "\n",
    "    # Load the pretrained parameters into the network.\n",
    "    load_param_into_net(network, load_checkpoint(load_checkpoint_path))\n",
    "\n",
    "    # Wrap the network with the optimizer and dynamic loss scaling, then train.\n",
    "    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)\n",
    "    netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)\n",
    "    model = Model(netwithgrads)\n",
    "    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch), ckpoint_cb]\n",
    "    model.train(epoch_num, dataset, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_eval(dataset=None, load_checkpoint_path=\"\", eval_batch_size=1):\n",
    "    \"\"\"Run the fine-tuned network over `dataset` and collect the raw logits.\n",
    "\n",
    "    Args:\n",
    "        dataset: Evaluation dataset yielding the columns \"input_ids\", \"input_mask\",\n",
    "            \"segment_ids\" and \"unique_ids\".\n",
    "        load_checkpoint_path: Path of the fine-tuned checkpoint (required).\n",
    "        eval_batch_size: Kept for interface compatibility; the number of rows read\n",
    "            from each prediction is taken from the returned batch itself.\n",
    "\n",
    "    Returns:\n",
    "        list: One RawResult(unique_id, start_logits, end_logits) per evaluated row.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `load_checkpoint_path` is empty.\n",
    "    \"\"\"\n",
    "    if load_checkpoint_path == \"\":\n",
    "        raise ValueError(\"Finetune model missed, evaluation task must load finetune model!\")\n",
    "    net = BertSquad(bert_net_cfg, False, 2)\n",
    "    net.set_train(False)\n",
    "    param_dict = load_checkpoint(load_checkpoint_path)\n",
    "    load_param_into_net(net, param_dict)\n",
    "    model = Model(net)\n",
    "\n",
    "    RawResult = collections.namedtuple(\"RawResult\", [\"unique_id\", \"start_logits\", \"end_logits\"])\n",
    "    columns_list = [\"input_ids\", \"input_mask\", \"segment_ids\", \"unique_ids\"]\n",
    "\n",
    "    # Placeholder label inputs passed to predict(); they never change, so they are\n",
    "    # built once instead of on every batch.\n",
    "    # NOTE(review): presumably ignored by the network in eval mode - confirm in BertSquad.\n",
    "    start_positions = Tensor([1], mstype.float32)\n",
    "    end_positions = Tensor([1], mstype.float32)\n",
    "    is_impossible = Tensor([1], mstype.float32)\n",
    "\n",
    "    output = []\n",
    "    for data in dataset.create_dict_iterator(num_epochs=1):\n",
    "        input_ids, input_mask, segment_ids, unique_ids = [data[name] for name in columns_list]\n",
    "        logits = model.predict(input_ids, input_mask, segment_ids, start_positions,\n",
    "                               end_positions, unique_ids, is_impossible)\n",
    "        ids = logits[0].asnumpy()\n",
    "        start = logits[1].asnumpy()\n",
    "        end = logits[2].asnumpy()\n",
    "\n",
    "        # Walk the rows actually present in this batch rather than trusting\n",
    "        # eval_batch_size: a short final batch would otherwise raise IndexError and\n",
    "        # a larger dataset batch would silently lose rows.\n",
    "        for row in range(len(ids)):\n",
    "            output.append(RawResult(\n",
    "                unique_id=int(ids[row]),\n",
    "                start_logits=[float(x) for x in start[row].flat],\n",
    "                end_logits=[float(x) for x in end[row].flat]))\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration for SQuAD fine-tuning and evaluation.\n",
    "# Boolean switches are stored as the strings \"true\"/\"false\".\n",
    "args_opt = edict(dict(\n",
    "    # Device and run switches\n",
    "    device_target=\"Ascend\",\n",
    "    do_train=\"true\",\n",
    "    do_eval=\"true\",\n",
    "    # Training / evaluation settings\n",
    "    epoch_num=3,\n",
    "    num_class=2,\n",
    "    train_data_shuffle=\"false\",\n",
    "    eval_data_shuffle=\"false\",\n",
    "    train_batch_size=32,\n",
    "    eval_batch_size=1,\n",
    "    # Input / output paths (an empty string means \"not set\")\n",
    "    vocab_file_path=\"./squad/vocab_bert_large_en.txt\",\n",
    "    save_finetune_checkpoint_path=\"\",\n",
    "    load_pretrain_checkpoint_path=\"./squad/bert_converted.ckpt\",\n",
    "    load_finetune_checkpoint_path=\"./squad-3_2745.ckpt\",\n",
    "    train_data_file_path=\"./squad/train.tf_record\",\n",
    "    eval_json_path=\"./squad/dev-v1.1.json\",\n",
    "    schema_file_path=\"\",\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:27.589.156 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:28.339.415 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:29.915.23 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:29.855.040 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:30.620.489 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:31.374.410 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:32.131.223 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:32.889.806 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:33.657.877 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:34.417.506 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:35.176.229 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:10:35.936.452 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:16.914.060 [mindspore/dataset/engine/datasets.py:1865] Repeat is located before batch, data from two epochs can be batched together.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:20.380.04 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:20.800.065 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:21.581.869 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:22.350.533 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:23.122.902 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:23.890.777 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:24.671.108 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:25.434.968 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:26.209.131 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:26.970.451 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:27.750.138 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n",
      "[WARNING] ME(184813:281473105954352,MainProcess):2021-03-01-17:18:28.512.393 [mindspore/ops/operations/nn_ops.py:2979] WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"exact_match\": 80.51087984862819, \"f1\": 87.94542868231342}\n"
     ]
    }
   ],
   "source": [
    "# Driver: configure the device, optionally fine-tune, then optionally evaluate on SQuAD.\n",
    "target = args_opt.device_target\n",
    "if target == \"Ascend\":\n",
    "    # NOTE(review): device_id is hard-coded to 1; adjust for the local machine.\n",
    "    context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=1)\n",
    "elif target == \"GPU\":\n",
    "    context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")\n",
    "    if bert_net_cfg.compute_type != mstype.float32:\n",
    "        logger.warning('GPU only support fp32 temporarily, run with fp32.')\n",
    "        bert_net_cfg.compute_type = mstype.float32\n",
    "else:\n",
    "    raise Exception(\"Target error, GPU or Ascend is supported.\")\n",
    "\n",
    "netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)\n",
    "\n",
    "# Checkpoint that evaluation will load. It defaults to the configured path and is\n",
    "# replaced below by the result of LoadNewestCkpt when training runs in this cell,\n",
    "# so evaluation no longer ignores the checkpoint that training just produced.\n",
    "load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path\n",
    "\n",
    "if args_opt.do_train.lower() == \"true\":\n",
    "    ds = create_squad_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,\n",
    "                              data_file_path=args_opt.train_data_file_path,\n",
    "                              schema_file_path=args_opt.schema_file_path,\n",
    "                              do_shuffle=(args_opt.train_data_shuffle.lower() == \"true\"))\n",
    "    do_train(ds, netwithloss, args_opt.load_pretrain_checkpoint_path, args_opt.save_finetune_checkpoint_path, args_opt.epoch_num)\n",
    "    if args_opt.do_eval.lower() == \"true\":\n",
    "        # With no save directory configured, look for the checkpoint in the\n",
    "        # working directory captured when the notebook started.\n",
    "        if args_opt.save_finetune_checkpoint_path == \"\":\n",
    "            load_finetune_checkpoint_dir = _cur_dir\n",
    "        else:\n",
    "            load_finetune_checkpoint_dir = make_directory(args_opt.save_finetune_checkpoint_path)\n",
    "        load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,\n",
    "                                                       ds.get_dataset_size(), args_opt.epoch_num, \"squad\")\n",
    "\n",
    "if args_opt.do_eval.lower() == \"true\":\n",
    "    from src import tokenization\n",
    "    from src.create_squad_data import read_squad_examples, convert_examples_to_features\n",
    "    from src.squad_get_predictions import write_predictions\n",
    "    from src.squad_postprocess import SQuad_postprocess\n",
    "    tokenizer = tokenization.FullTokenizer(vocab_file=args_opt.vocab_file_path, do_lower_case=True)\n",
    "    eval_examples = read_squad_examples(args_opt.eval_json_path, False)\n",
    "    eval_features = convert_examples_to_features(\n",
    "        examples=eval_examples,\n",
    "        tokenizer=tokenizer,\n",
    "        max_seq_length=bert_net_cfg.seq_length,\n",
    "        doc_stride=128,\n",
    "        max_query_length=64,\n",
    "        is_training=False,\n",
    "        output_fn=None,\n",
    "        vocab_file=args_opt.vocab_file_path)\n",
    "    ds = create_squad_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,\n",
    "                              data_file_path=eval_features,\n",
    "                              schema_file_path=args_opt.schema_file_path, is_training=False,\n",
    "                              do_shuffle=(args_opt.eval_data_shuffle.lower() == \"true\"))\n",
    "    # Evaluate the resolved checkpoint (newest after training, else the configured one).\n",
    "    outputs = do_eval(ds, load_finetune_checkpoint_path, args_opt.eval_batch_size)\n",
    "    all_predictions = write_predictions(eval_examples, eval_features, outputs, 20, 30, True)\n",
    "    SQuad_postprocess(args_opt.eval_json_path, all_predictions, output_metrics=\"output.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
